package face6wap;

import face5wap.LS_WGAN;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration.GraphBuilder;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.StackVertex;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.impl.NormalDistribution;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import util.ShowUtils;
import util.ShowUtilsT;

import java.io.File;
import java.util.Arrays;
import java.util.Map;

/**
 * Note:
 * Derived from Gan_one. Demonstrates splitting the discriminator out of the
 * ComputationGraph and synchronizing its trained parameters back into the full GAN graph.
 */
public class Gan_two_dis {
	private static final Logger log = LoggerFactory.getLogger(Gan_two_dis.class);

	/** Learning rate shared by the generator and discriminator layers. */
	static double lr = 0.002;
	/** Checkpoint path for the full GAN graph. */
	static String model = "F:/face/gan.zip";

	/** Discriminator layers that exist (with identical shapes) in both graphs. */
	private static final String[] DIS_LAYERS = { "d1", "d2", "d3", "out" };

	/**
	 * Trains a simple MNIST GAN. The full graph {@code net} contains generator (g1-g3,
	 * 10-dim noise -> 28*28 image) plus discriminator (d1-d3, out); a standalone copy of
	 * the discriminator {@code dis} is trained on real/fake batches, and its weights are
	 * copied back into {@code net} before each generator step.
	 *
	 * Label convention used throughout: real = 0, fake = 1.
	 */
	public static void main(String[] args) throws Exception {
		// Periodic GC so ND4J workspace memory is reclaimed during the long training loop.
		Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);

		// Full GAN graph: noise -> generator -> discriminator -> sigmoid real/fake score.
		final GraphBuilder ganModel = new NeuralNetConfiguration.Builder().updater(new Sgd(lr))
				.weightInit(WeightInit.XAVIER).graphBuilder().backpropType(BackpropType.Standard)
				.addInputs("input1")
				.addLayer("g1",
						new DenseLayer.Builder().nIn(10).nOut(128).activation(Activation.RELU)
								.weightInit(WeightInit.XAVIER).build(), "input1")
				.addLayer("g2",
						new DenseLayer.Builder().nIn(128).nOut(512).activation(Activation.RELU)
								.weightInit(WeightInit.XAVIER).build(), "g1")
				.addLayer("g3",
						new DenseLayer.Builder().nIn(512).nOut(28 * 28).activation(Activation.RELU)
								.weightInit(WeightInit.XAVIER).build(), "g2")
				.addLayer("d1",
						new DenseLayer.Builder().nIn(28 * 28).nOut(256).activation(Activation.RELU)
								.weightInit(WeightInit.XAVIER).build(), "g3")
				.addLayer("d2",
						new DenseLayer.Builder().nIn(256).nOut(128).activation(Activation.RELU)
								.weightInit(WeightInit.XAVIER).build(), "d1")
				.addLayer("d3",
						new DenseLayer.Builder().nIn(128).nOut(128).activation(Activation.RELU)
								.weightInit(WeightInit.XAVIER).build(), "d2")
				.addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.XENT).nIn(128).nOut(1)
						.activation(Activation.SIGMOID).build(), "d3")
				.setOutputs("out");

		// Standalone discriminator: identical d-layer topology, fed directly with images.
		final GraphBuilder disModel = new NeuralNetConfiguration.Builder().updater(new Sgd(lr))
				.weightInit(WeightInit.XAVIER).graphBuilder().backpropType(BackpropType.Standard)
				.addInputs("input1")
				.addLayer("d1", new DenseLayer.Builder().nIn(28 * 28).nOut(256).activation(Activation.RELU)
								.weightInit(WeightInit.XAVIER).build(), "input1")
				.addLayer("d2", new DenseLayer.Builder().nIn(256).nOut(128).activation(Activation.RELU)
								.weightInit(WeightInit.XAVIER).build(), "d1")
				.addLayer("d3", new DenseLayer.Builder().nIn(128).nOut(128).activation(Activation.RELU)
								.weightInit(WeightInit.XAVIER).build(), "d2")
				.addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.XENT).nIn(128).nOut(1)
						.activation(Activation.SIGMOID).build(), "d3")
				.setOutputs("out");

		ComputationGraph net = new ComputationGraph(ganModel.build());
		ComputationGraph dis = new ComputationGraph(disModel.build());
		net.init();
		dis.init();
		System.out.println(net.summary());
		System.out.println(dis.summary());
		dis.setListeners(new ScoreIterationListener(100));

		DataSetIterator train = new MnistDataSetIterator(30, true, 12345);
		INDArray realLabel = Nd4j.zeros(30, 1);
		INDArray fakeLabel = Nd4j.ones(30, 1);
		for (int i = 1; i <= 100000; i++) {
			if (!train.hasNext()) {
				train.reset();
			}
			INDArray trueExp = train.next().getFeatures();

			// Generate a fake batch: sample Gaussian noise and read the generator head ("g3").
			INDArray z = Nd4j.rand(new NormalDistribution(), new long[] { 30, 10 });
			Map<String, INDArray> mapfake = net.feedForward(new INDArray[] { z }, false);
			INDArray fake = mapfake.get("g3");
			MultiDataSet realSet = new MultiDataSet(new INDArray[] { trueExp },
					new INDArray[] { realLabel });
			MultiDataSet fakeSet = new MultiDataSet(new INDArray[] { fake },
					new INDArray[] { fakeLabel });
			// Train the standalone discriminator 5x per generator step.
			for (int m = 0; m < 5; m++) {
				dis.fit(realSet);
				dis.fit(fakeSet);
			}
			Nd4j.clearNans(fake);

			// Sync the freshly trained discriminator weights into the full graph's d-layers.
			copyDisParams(dis, net);

			// Generator step: target the REAL label (0) so the g-layers learn to fool the
			// discriminator. (Bug fix: the original trained against fakeLabel = 1, which
			// rewards the generator for being classified as fake — the opposite objective.)
			z = Nd4j.rand(new NormalDistribution(), new long[] { 30, 10 });
			MultiDataSet dataSetG = new MultiDataSet(new INDArray[] { z },
					new INDArray[] { realLabel });
			trainG(net, dataSetG);
			Nd4j.clearNans(z);

			if (i % 50 == 0) {
				// Periodically render a sample batch from the current generator.
				INDArray noise = Nd4j.rand(new NormalDistribution(), new long[] { 50, 10 });
				Map<String, INDArray> map1 = net.feedForward(new INDArray[] { noise }, false);
				INDArray indArray2 = map1.get("g3");
				// Bug fix: the original allocated size(0) slots but filled only index 0,
				// handing null entries to the visualizer.
				INDArray[] samples = new INDArray[] { indArray2 };
				ShowUtils.visualize(samples, "拆分");
			}

			if (i % 10000 == 0) {
				net.save(new File(model), true);
			}
		}
	}

	/**
	 * Copies the weights and biases of every discriminator layer from {@code dis}
	 * into the layer of the same name in {@code net}.
	 */
	private static void copyDisParams(ComputationGraph dis, ComputationGraph net) {
		for (String layer : DIS_LAYERS) {
			net.getLayer(layer).setParam("W", dis.getLayer(layer).getParam("W"));
			net.getLayer(layer).setParam("b", dis.getLayer(layer).getParam("b"));
		}
	}

	/**
	 * Runs one generator update on the full GAN graph: the g-layers train at {@code lr}
	 * while all discriminator layers are frozen (learning rate 0), so the loss gradient
	 * flows through the fixed discriminator and only updates the generator.
	 */
	public static void trainG(ComputationGraph net, MultiDataSet dataSet) {
		net.setLearningRate("g1", lr);
		net.setLearningRate("g2", lr);
		net.setLearningRate("g3", lr);
		for (String layer : DIS_LAYERS) {
			net.setLearningRate(layer, 0);
		}
		net.fit(dataSet);
	}
}